Conversation
📝 WalkthroughWalkthroughThe changes strengthen type annotations across acquisition function interfaces and update the Bayesian optimization public API. AcquisitionFunction's Changes
Estimated code review effort🎯 2 (Simple) | ⏱️ ~10 minutes Poem
🚥 Pre-merge checks | ✅ 4✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing touches
🧪 Generate unit tests (beta)
No actionable comments were generated in the recent review. 🎉 🧹 Recent nitpick comments
Tip Issue Planner is now in beta. Read the docs and try it out! Share your feedback on Discord. Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
Codecov Report✅ All modified and coverable lines are covered by tests. Additional details and impacted files@@ Coverage Diff @@
## master #600 +/- ##
=======================================
Coverage 97.78% 97.78%
=======================================
Files 10 10
Lines 1221 1221
=======================================
Hits 1194 1194
Misses 27 27 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
till-m
left a comment
There was a problem hiding this comment.
LGTM, but I could you have a look at the one comment I added?
| "state": state_dict["state"]["key"].tolist(), | ||
| "pos": state_dict["state"]["pos"], | ||
| "has_gauss": state_dict["has_gauss"], | ||
| "cached_gaussian": state_dict["gauss"], |
There was a problem hiding this comment.
Why don't the keys line up here ("cached_gaussian" vs "gauss")?
There was a problem hiding this comment.
I don't fully understand the meaning of each value, nor do I know the exact intention behind each method. Therefore, I preserved the keys as they were to avoid causing issues in other code that consumes this function's return value. I only changed the internal logic that depended on legacy.
My primary concern was errors occurring for users who were already using this method. So my goal was to maintain the existing return structure. When legacy=true, get_state returned a value with the following structure,
(
"MT19937",
<numpy.ndarray>,
int, # like 623
int, # like 0
float # like 0.0
)and when legacy=false, get_state returned a value with the following structure.
{
"bit_generator": "MT19937",
"state": {
"key": <numpy.ndarray>,
"pos": int # like 623
},
"has_gauss": int, # like 0
"gauss": float # like 0.0
}For that reason, I assumed gauss and cached_gauss were used with the same meaning. Using gauss instead of cached_gauss to align the numpy return value with the keys is also a good idea. However, it should be announced as a breaking change.
sample code
from pprint import pprint
from bayes_opt.bayesian_optimization import BayesianOptimization
bayes = BayesianOptimization(None, {})
legacy = bayes._random_state.get_state(legacy=True)
non_legacy = bayes._random_state.get_state(legacy=False)
print("Legacy state:")
pprint(legacy)
print("Non-legacy state:")
pprint(non_legacy)output
Legacy state:
('MT19937',
array([2147483648, 2075331599, 140858681, 3526623561, 2541108888,
2106811263, 3519418634, 407860018, 2249244654, 2606184075,
2786589483, 1437634829, 3069487802, 2325528976, 2221173448,
4175430749, 60802753, 3831120806, 3576967720, 1467437066,
3102226502, 1109303602, 4202242805, 3948013612, 3365853984,
1662141710, 665954637, 3930982131, 573623358, 3534123242,
1163249977, 1484804157, 2526724740, 4041334237, 303997122,
467039403, 2604812076, 2662108352, 2700590779, 3149658310,
475663908, 2698417034, 1437811983, 2453274244, 3934471757,
1331634643, 3476886289, 1214548185, 2557676973, 841994067,
1486750722, 1743267989, 2558066518, 1577953441, 2313272813,
553286535, 1808628211, 2154155001, 1626207377, 3930160892,
2129127017, 2082629614, 3872620351, 2201093007, 2516576609,
1033987357, 2408204758, 15660947, 104784592, 59865213,
186913512, 2677884128, 2060217906, 2669113803, 2011873750,
208427217, 3680958504, 2179926499, 3166739308, 1431064049,
2975913278, 1687432296, 1015884945, 1679512422, 4236327727,
789327648, 2432065914, 3711318603, 2856411391, 3648323176,
2833858325, 2966091895, 3017587410, 2267289704, 1257721755,
1729437608, 1424227325, 2923596926, 1826322183, 368372361,
1488204762, 1945476249, 562222798, 493888013, 2257447923,
378468050, 2023039273, 3634128929, 2212812842, 3736007970,
3181848715, 3253178714, 2247387434, 630298605, 168071064,
2453990713, 2754709023, 158309059, 1693133287, 425567776,
610244961, 1099495700, 1651920650, 4001234624, 4215841990,
1597420538, 1371783983, 782335516, 955314030, 2124985108,
1230992899, 2554470207, 178356559, 1538520927, 3123923199,
98816670, 19339993, 1211323375, 2192833935, 3550014245,
4079832419, 763245682, 3114649865, 637250886, 2346230918,
1613897166, 900697267, 717509049, 3148396806, 860799586,
2074079507, 3061186827, 2657283036, 1289050313, 224279970,
2147563708, 1342253071, 1099805121, 1986459658, 451923538,
4151091570, 4093933368, 1531334554, 2482301065, 3017676208,
1552415979, 4080116605, 1391877515, 3326970419, 2087771436,
881497654, 409221337, 2336918916, 3495750864, 505102391,
87044493, 1374578224, 2345228469, 2677354761, 1686087426,
350682568, 3896700142, 3387441124, 1322844580, 2035232113,
4171743547, 2761364203, 87729543, 1543233909, 2006521429,
1975119445, 37244549, 2381364078, 1787348020, 1892866190,
4158911646, 2832543291, 2923372499, 933489775, 1999573074,
1532857763, 3880893325, 2768633145, 1953348816, 3266441215,
2573231980, 2998278938, 2040706335, 3623219627, 3178406798,
4283796809, 1300309756, 1758157157, 4250966860, 2653105583,
3650054898, 2689840365, 3155536137, 1503944792, 1147189469,
3536335302, 3769476577, 2815860922, 4078024620, 2513323216,
1902584155, 4129754107, 4152359938, 226115604, 851253201,
4280862308, 180287939, 2460657108, 2761413323, 1164459732,
4138592784, 1128464012, 2058279714, 28513951, 138939637,
1445273851, 2902445187, 1819904997, 4126426761, 2232230179,
1902241997, 1982429477, 3030253536, 2390946236, 836566848,
2165522031, 2843209162, 1133455873, 547623528, 207067323,
3740164691, 2956947873, 3885177903, 3602466789, 2265740401,
2685171170, 2058459229, 95039846, 2428394435, 4085737081,
1250884434, 2137775703, 3615871800, 736326573, 1908407987,
3232375302, 4025275217, 3344578052, 419050246, 539900722,
361845484, 2024236041, 1807301781, 2207338269, 1752779524,
1862418620, 1495823156, 1133973867, 1296654836, 2940007730,
949464795, 1328598876, 2353612020, 1256954613, 2085080443,
3735562356, 4031242042, 2376353226, 1227464517, 3957619112,
1276759908, 169813089, 2554310499, 2362132599, 734463866,
2567295435, 3713429466, 1660269273, 2342554397, 1944535609,
2590579477, 2663485021, 3446727181, 1776368335, 4209700934,
325303151, 1020536745, 1713248288, 1413553117, 1011704881,
3010455506, 867526459, 125872204, 2823501223, 315675245,
664493829, 1193647001, 2902435491, 4008302996, 3293832381,
2797902120, 921910079, 3875536258, 3860489956, 2859620526,
4069365937, 3292188807, 1138527378, 1843896811, 907091389,
2178070895, 1184356041, 2913034589, 533879758, 1047873800,
4089462894, 468348870, 1858800455, 2604678205, 2354182895,
4037991346, 3209323702, 3533093419, 1919473510, 2793616847,
1700883469, 1788826146, 4071958955, 3386785762, 31514172,
19250776, 2584513890, 3387325066, 2742147133, 3944682594,
592591404, 2791260675, 2659746612, 315474544, 1904698953,
2604266677, 353504318, 1964074412, 3514485156, 3169532567,
2790485967, 1849610563, 1030849375, 1970983638, 988474342,
3618121941, 2878575277, 2344105764, 2433708147, 3743911834,
4228347785, 2663980549, 198679867, 251346502, 3680098503,
717982583, 1991968231, 1966215759, 4175925001, 3436559417,
2842484640, 1785986318, 3282130015, 4209173822, 3214474404,
2458053183, 566868871, 4247691298, 1194533791, 413487128,
2956434281, 1976050740, 2603640364, 4002305523, 2189933609,
2235687021, 4111498140, 362578657, 3802239775, 1449348735,
3201915429, 1621039138, 1260879657, 3685433319, 3748224261,
515189256, 76609166, 4023105783, 1332479920, 1941260488,
3789521841, 2649225611, 3982178583, 2640343581, 887753876,
2623879481, 654393789, 3901261877, 4129195216, 2072705798,
3926013998, 134872244, 3851146837, 949460971, 3028829716,
2748465327, 658633327, 2567012781, 1903893367, 4261900666,
3594114567, 3153823335, 1395439952, 2105875768, 381976650,
3596252054, 2194541858, 1669787451, 1574190922, 2367623075,
2588833298, 15549328, 3121584353, 1205559863, 4005798159,
385823106, 1512495320, 2519004920, 3116316035, 1356388453,
2213505697, 3940907241, 662747692, 4271952413, 2238060344,
2185588921, 1133176214, 2634200703, 903739507, 1276798518,
1983926732, 2154820370, 667506551, 1019684315, 2466082561,
470204603, 1310064920, 1475141280, 245892548, 3605779030,
2324199318, 3038548256, 2392772418, 3434801095, 1665423914,
3361599793, 991532629, 203341667, 2271581211, 342113848,
1314742753, 2606385127, 4012521188, 2525627142, 29727465,
1440846172, 3385301785, 1416716301, 3139018885, 1029852703,
4028502064, 3017404021, 2410613923, 3596475680, 2001792110,
2708833257, 14648227, 1907913989, 1170726985, 3111440312,
846965399, 3980136833, 588455565, 4078451876, 3074024037,
2462917040, 2116906947, 3934436462, 2404005766, 326965248,
1973564508, 247294817, 3955932918, 2643613999, 250186407,
1225615873, 1839696651, 1566941656, 3381265620, 2592892272,
1941384363, 3050443439, 3399139892, 1559644122, 3375602111,
1882461001, 687452999, 1468410266, 2806653366, 1904450878,
31291347, 2419739254, 1417596946, 1640466810, 1930448654,
2936089149, 2312610571, 251748476, 437649019, 176021521,
3372211891, 3436408750, 2737092954, 2371839491, 3899285718,
1432435164, 1150021701, 294106891, 3716975008, 925497236,
172267418, 3075465779, 601651008, 3068156317, 478671785,
541454057, 3549429635, 3474732129, 3109897059, 2882340945,
413340895, 1445184158, 659367875, 2246342184, 2990547884,
612017351, 1742799743, 378185270, 144127691, 3292083012,
257622965, 2144914484, 3497476539, 4118294021, 1412081876,
1911936809, 216408794, 4261218624, 3956031618, 3845031202,
4158384247, 3583716646, 2226653414, 3830924459, 2579468707,
2114341243, 1811121245, 1825658109, 938532476, 156167731,
3643681680, 3305159196, 2873238718, 2179049820, 1445742370,
2797236726, 2573165619, 2478500615, 1701106298, 457138435,
21172895, 3301268408, 750845870, 363451077, 1772226286,
3727614291, 956968052, 250648236, 2185516024, 3583558949,
781299039, 2536436898, 2114876194, 1369923866], dtype=uint32),
623,
0,
0.0)
Non-legacy state:
{'bit_generator': 'MT19937',
'gauss': 0.0,
'has_gauss': 0,
'state': {'key': array([2147483648, 2075331599, 140858681, 3526623561, 2541108888,
2106811263, 3519418634, 407860018, 2249244654, 2606184075,
2786589483, 1437634829, 3069487802, 2325528976, 2221173448,
4175430749, 60802753, 3831120806, 3576967720, 1467437066,
3102226502, 1109303602, 4202242805, 3948013612, 3365853984,
1662141710, 665954637, 3930982131, 573623358, 3534123242,
1163249977, 1484804157, 2526724740, 4041334237, 303997122,
467039403, 2604812076, 2662108352, 2700590779, 3149658310,
475663908, 2698417034, 1437811983, 2453274244, 3934471757,
1331634643, 3476886289, 1214548185, 2557676973, 841994067,
1486750722, 1743267989, 2558066518, 1577953441, 2313272813,
553286535, 1808628211, 2154155001, 1626207377, 3930160892,
2129127017, 2082629614, 3872620351, 2201093007, 2516576609,
1033987357, 2408204758, 15660947, 104784592, 59865213,
186913512, 2677884128, 2060217906, 2669113803, 2011873750,
208427217, 3680958504, 2179926499, 3166739308, 1431064049,
2975913278, 1687432296, 1015884945, 1679512422, 4236327727,
789327648, 2432065914, 3711318603, 2856411391, 3648323176,
2833858325, 2966091895, 3017587410, 2267289704, 1257721755,
1729437608, 1424227325, 2923596926, 1826322183, 368372361,
1488204762, 1945476249, 562222798, 493888013, 2257447923,
378468050, 2023039273, 3634128929, 2212812842, 3736007970,
3181848715, 3253178714, 2247387434, 630298605, 168071064,
2453990713, 2754709023, 158309059, 1693133287, 425567776,
610244961, 1099495700, 1651920650, 4001234624, 4215841990,
1597420538, 1371783983, 782335516, 955314030, 2124985108,
1230992899, 2554470207, 178356559, 1538520927, 3123923199,
98816670, 19339993, 1211323375, 2192833935, 3550014245,
4079832419, 763245682, 3114649865, 637250886, 2346230918,
1613897166, 900697267, 717509049, 3148396806, 860799586,
2074079507, 3061186827, 2657283036, 1289050313, 224279970,
2147563708, 1342253071, 1099805121, 1986459658, 451923538,
4151091570, 4093933368, 1531334554, 2482301065, 3017676208,
1552415979, 4080116605, 1391877515, 3326970419, 2087771436,
881497654, 409221337, 2336918916, 3495750864, 505102391,
87044493, 1374578224, 2345228469, 2677354761, 1686087426,
350682568, 3896700142, 3387441124, 1322844580, 2035232113,
4171743547, 2761364203, 87729543, 1543233909, 2006521429,
1975119445, 37244549, 2381364078, 1787348020, 1892866190,
4158911646, 2832543291, 2923372499, 933489775, 1999573074,
1532857763, 3880893325, 2768633145, 1953348816, 3266441215,
2573231980, 2998278938, 2040706335, 3623219627, 3178406798,
4283796809, 1300309756, 1758157157, 4250966860, 2653105583,
3650054898, 2689840365, 3155536137, 1503944792, 1147189469,
3536335302, 3769476577, 2815860922, 4078024620, 2513323216,
1902584155, 4129754107, 4152359938, 226115604, 851253201,
4280862308, 180287939, 2460657108, 2761413323, 1164459732,
4138592784, 1128464012, 2058279714, 28513951, 138939637,
1445273851, 2902445187, 1819904997, 4126426761, 2232230179,
1902241997, 1982429477, 3030253536, 2390946236, 836566848,
2165522031, 2843209162, 1133455873, 547623528, 207067323,
3740164691, 2956947873, 3885177903, 3602466789, 2265740401,
2685171170, 2058459229, 95039846, 2428394435, 4085737081,
1250884434, 2137775703, 3615871800, 736326573, 1908407987,
3232375302, 4025275217, 3344578052, 419050246, 539900722,
361845484, 2024236041, 1807301781, 2207338269, 1752779524,
1862418620, 1495823156, 1133973867, 1296654836, 2940007730,
949464795, 1328598876, 2353612020, 1256954613, 2085080443,
3735562356, 4031242042, 2376353226, 1227464517, 3957619112,
1276759908, 169813089, 2554310499, 2362132599, 734463866,
2567295435, 3713429466, 1660269273, 2342554397, 1944535609,
2590579477, 2663485021, 3446727181, 1776368335, 4209700934,
325303151, 1020536745, 1713248288, 1413553117, 1011704881,
3010455506, 867526459, 125872204, 2823501223, 315675245,
664493829, 1193647001, 2902435491, 4008302996, 3293832381,
2797902120, 921910079, 3875536258, 3860489956, 2859620526,
4069365937, 3292188807, 1138527378, 1843896811, 907091389,
2178070895, 1184356041, 2913034589, 533879758, 1047873800,
4089462894, 468348870, 1858800455, 2604678205, 2354182895,
4037991346, 3209323702, 3533093419, 1919473510, 2793616847,
1700883469, 1788826146, 4071958955, 3386785762, 31514172,
19250776, 2584513890, 3387325066, 2742147133, 3944682594,
592591404, 2791260675, 2659746612, 315474544, 1904698953,
2604266677, 353504318, 1964074412, 3514485156, 3169532567,
2790485967, 1849610563, 1030849375, 1970983638, 988474342,
3618121941, 2878575277, 2344105764, 2433708147, 3743911834,
4228347785, 2663980549, 198679867, 251346502, 3680098503,
717982583, 1991968231, 1966215759, 4175925001, 3436559417,
2842484640, 1785986318, 3282130015, 4209173822, 3214474404,
2458053183, 566868871, 4247691298, 1194533791, 413487128,
2956434281, 1976050740, 2603640364, 4002305523, 2189933609,
2235687021, 4111498140, 362578657, 3802239775, 1449348735,
3201915429, 1621039138, 1260879657, 3685433319, 3748224261,
515189256, 76609166, 4023105783, 1332479920, 1941260488,
3789521841, 2649225611, 3982178583, 2640343581, 887753876,
2623879481, 654393789, 3901261877, 4129195216, 2072705798,
3926013998, 134872244, 3851146837, 949460971, 3028829716,
2748465327, 658633327, 2567012781, 1903893367, 4261900666,
3594114567, 3153823335, 1395439952, 2105875768, 381976650,
3596252054, 2194541858, 1669787451, 1574190922, 2367623075,
2588833298, 15549328, 3121584353, 1205559863, 4005798159,
385823106, 1512495320, 2519004920, 3116316035, 1356388453,
2213505697, 3940907241, 662747692, 4271952413, 2238060344,
2185588921, 1133176214, 2634200703, 903739507, 1276798518,
1983926732, 2154820370, 667506551, 1019684315, 2466082561,
470204603, 1310064920, 1475141280, 245892548, 3605779030,
2324199318, 3038548256, 2392772418, 3434801095, 1665423914,
3361599793, 991532629, 203341667, 2271581211, 342113848,
1314742753, 2606385127, 4012521188, 2525627142, 29727465,
1440846172, 3385301785, 1416716301, 3139018885, 1029852703,
4028502064, 3017404021, 2410613923, 3596475680, 2001792110,
2708833257, 14648227, 1907913989, 1170726985, 3111440312,
846965399, 3980136833, 588455565, 4078451876, 3074024037,
2462917040, 2116906947, 3934436462, 2404005766, 326965248,
1973564508, 247294817, 3955932918, 2643613999, 250186407,
1225615873, 1839696651, 1566941656, 3381265620, 2592892272,
1941384363, 3050443439, 3399139892, 1559644122, 3375602111,
1882461001, 687452999, 1468410266, 2806653366, 1904450878,
31291347, 2419739254, 1417596946, 1640466810, 1930448654,
2936089149, 2312610571, 251748476, 437649019, 176021521,
3372211891, 3436408750, 2737092954, 2371839491, 3899285718,
1432435164, 1150021701, 294106891, 3716975008, 925497236,
172267418, 3075465779, 601651008, 3068156317, 478671785,
541454057, 3549429635, 3474732129, 3109897059, 2882340945,
413340895, 1445184158, 659367875, 2246342184, 2990547884,
612017351, 1742799743, 378185270, 144127691, 3292083012,
257622965, 2144914484, 3497476539, 4118294021, 1412081876,
1911936809, 216408794, 4261218624, 3956031618, 3845031202,
4158384247, 3583716646, 2226653414, 3830924459, 2579468707,
2114341243, 1811121245, 1825658109, 938532476, 156167731,
3643681680, 3305159196, 2873238718, 2179049820, 1445742370,
2797236726, 2573165619, 2478500615, 1701106298, 457138435,
21172895, 3301268408, 750845870, 363451077, 1772226286,
3727614291, 956968052, 250648236, 2185516024, 3583558949,
781299039, 2536436898, 2114876194, 1369923866], dtype=uint32),
'pos': 623}}
There was a problem hiding this comment.
okay, I see, it's for legacy compatibility. There is the option of breaking this also, but if you think this is the way to go, I'm happy to merge like this :)
There was a problem hiding this comment.
I think this is sufficient for now since this PR is for type safety, and if we were to change the structure of the random state, it should be addressed in a separate PR. (And since it's functionality for saving and loading JSON files, we need to consider loading as well (such as adding legacy compatibility logic, etc.))
set_acquisition_paramsreceives only a single variable named params, not keyword-only arguments. (I inferred this from methods in classes that inherit and implement it.)random_samplereturns a list whose elements are dicts, not a dict.save_state, when using numpy's random_state, I explicitly passed the argument to avoid using the legacy behavior (legacy=False).dicttodict[str, Any]. (From the actual code, I assumed the keys are always strings.)Summary by CodeRabbit
random_sample()method now returns a list of parameter sets instead of a single set.